{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-20T09:35:56.628130Z",
     "start_time": "2019-03-20T09:35:45.661384Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "matchzoo version 2.1.0\n",
      "\n",
      "data loading ...\n",
      "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n",
      "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n",
      "loading embedding ...\n",
      "embedding loaded as `glove_embedding`\n"
     ]
    }
   ],
   "source": [
    "%run init.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-20T09:35:56.633000Z",
     "start_time": "2019-03-20T09:35:56.630450Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "preprocessor = mz.preprocessors.BasicPreprocessor(fixed_length_left=10, fixed_length_right=100, remove_stop_words=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-20T09:36:06.249211Z",
     "start_time": "2019-03-20T09:35:56.634788Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 9354.30it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:03<00:00, 5242.39it/s]\n",
      "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 959307.59it/s]\n",
      "Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 144447.70it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 151024.12it/s]\n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 818306.55it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 742442.92it/s]\n",
      "Building Vocabulary from a datapack.: 100%|██████████| 404432/404432 [00:00<00:00, 2591651.03it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 9654.46it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:03<00:00, 5546.36it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 138535.63it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 254476.95it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 125907.17it/s]\n",
      "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 631516.02it/s]\n",
      "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 868328.96it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 138403.01it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 61442.21it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 122/122 [00:00<00:00, 9297.98it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 5474.46it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 123108.59it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 109807.96it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 142493.87it/s]\n",
      "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 193974.64it/s]\n",
      "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 688349.86it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 110281.27it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 86857.14it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 237/237 [00:00<00:00, 9041.67it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2300/2300 [00:00<00:00, 4067.08it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 123996.13it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 143565.86it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 129280.34it/s]\n",
      "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 309383.77it/s]\n",
      "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 887315.97it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 101288.98it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 85815.05it/s]\n"
     ]
    }
   ],
   "source": [
    "train_pack_processed = preprocessor.fit_transform(train_pack_raw)\n",
    "dev_pack_processed = preprocessor.transform(dev_pack_raw)\n",
    "test_pack_processed = preprocessor.transform(test_pack_raw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-20T09:36:06.262937Z",
     "start_time": "2019-03-20T09:36:06.253350Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'filter_unit': <matchzoo.preprocessors.units.frequency_filter.FrequencyFilter at 0x1115c2f28>,\n",
       " 'vocab_unit': <matchzoo.preprocessors.units.vocabulary.Vocabulary at 0x13ec386d8>,\n",
       " 'vocab_size': 16674,\n",
       " 'embedding_input_dim': 16674,\n",
       " 'input_shapes': [(10,), (100,)]}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preprocessor.context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-20T09:36:06.413530Z",
     "start_time": "2019-03-20T09:36:06.267256Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model_class                   <class 'matchzoo.models.drmmtks.DRMMTKS'>\n",
      "input_shapes                  [(10,), (100,)]\n",
      "task                          Ranking Task\n",
      "optimizer                     adadelta\n",
      "with_embedding                True\n",
      "embedding_input_dim           16674\n",
      "embedding_output_dim          100\n",
      "embedding_trainable           True\n",
      "with_multi_layer_perceptron   True\n",
      "mlp_num_units                 5\n",
      "mlp_num_layers                1\n",
      "mlp_num_fan_out               1\n",
      "mlp_activation_func           relu\n",
      "mask_value                    -1\n",
      "top_k                         20\n"
     ]
    }
   ],
   "source": [
    "model = mz.models.DRMMTKS()\n",
    "\n",
    "# load `input_shapes` and `embedding_input_dim` (vocab_size)\n",
    "model.params.update(preprocessor.context)\n",
    "\n",
    "model.params['task'] = ranking_task\n",
    "model.params['mask_value'] = -1\n",
    "model.params['embedding_output_dim'] = glove_embedding.output_dim\n",
    "model.params['embedding_trainable'] = True\n",
    "model.params['top_k'] = 20\n",
    "model.params['mlp_num_layers'] = 1\n",
    "model.params['mlp_num_units'] = 5\n",
    "model.params['mlp_num_fan_out'] = 1\n",
    "model.params['mlp_activation_func'] = 'relu'\n",
    "model.params['optimizer'] = 'adadelta'\n",
    "\n",
    "model.build()\n",
    "model.compile()\n",
    "\n",
    "print(model.params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-20T09:36:06.422264Z",
     "start_time": "2019-03-20T09:36:06.415605Z"
    },
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "text_left (InputLayer)          (None, 10)           0                                            \n",
      "__________________________________________________________________________________________________\n",
      "text_right (InputLayer)         (None, 100)          0                                            \n",
      "__________________________________________________________________________________________________\n",
      "embedding (Embedding)           multiple             1667400     text_left[0][0]                  \n",
      "                                                                 text_right[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "dot_1 (Dot)                     (None, 10, 100)      0           embedding[0][0]                  \n",
      "                                                                 embedding[1][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dense_1 (Dense)                 (None, 10, 1)        100         embedding[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "lambda_1 (Lambda)               (None, 10, 10)       0           dot_1[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "attention_mask (Lambda)         (None, 10, 1)        0           dense_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_2 (Dense)                 (None, 10, 5)        55          lambda_1[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "attention_probs (Lambda)        (None, 10, 1)        0           attention_mask[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "dense_3 (Dense)                 (None, 10, 1)        6           dense_2[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dot_2 (Dot)                     (None, 1, 1)         0           attention_probs[0][0]            \n",
      "                                                                 dense_3[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "flatten_1 (Flatten)             (None, 1)            0           dot_2[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "dense_4 (Dense)                 (None, 1)            2           flatten_1[0][0]                  \n",
      "==================================================================================================\n",
      "Total params: 1,667,563\n",
      "Trainable params: 1,667,563\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.backend.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-20T09:37:59.341616Z",
     "start_time": "2019-03-20T09:36:06.425086Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "term_index = preprocessor.context['vocab_unit'].state['term_index']\n",
    "embedding_matrix = glove_embedding.build_matrix(term_index)\n",
    "l2_norm = np.sqrt((embedding_matrix * embedding_matrix).sum(axis=1))\n",
    "embedding_matrix = embedding_matrix / l2_norm[:, np.newaxis]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-20T09:37:59.432719Z",
     "start_time": "2019-03-20T09:37:59.343309Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model.load_embedding_matrix(embedding_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-20T09:37:59.489842Z",
     "start_time": "2019-03-20T09:37:59.434509Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "test_x, test_y = test_pack_processed.unpack()\n",
    "evaluate = mz.callbacks.EvaluateAllMetrics(model, x=test_x, y=test_y, batch_size=len(test_x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-20T09:38:47.716275Z",
     "start_time": "2019-03-20T09:38:37.602828Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num batches: 408\n"
     ]
    }
   ],
   "source": [
    "train_generator = mz.DataGenerator(\n",
    "    train_pack_processed,\n",
    "    mode='pair',\n",
    "    num_dup=2,\n",
    "    num_neg=1,\n",
    "    batch_size=20\n",
    ")\n",
    "print('num batches:', len(train_generator))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-20T09:50:55.901154Z",
     "start_time": "2019-03-20T09:38:59.434426Z"
    },
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/30\n",
      "408/408 [==============================] - 7s 17ms/step - loss: 0.5841\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5988918016460031 - normalized_discounted_cumulative_gain@5(0.0): 0.6590526112135905 - mean_average_precision(0.0): 0.6149729212603963\n",
      "Epoch 2/30\n",
      "408/408 [==============================] - 16s 40ms/step - loss: 0.1286\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5986225503033814 - normalized_discounted_cumulative_gain@5(0.0): 0.6573151832954626 - mean_average_precision(0.0): 0.6081435794322055\n",
      "Epoch 3/30\n",
      "408/408 [==============================] - 17s 41ms/step - loss: 0.0236\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5770594699106932 - normalized_discounted_cumulative_gain@5(0.0): 0.642764019599169 - mean_average_precision(0.0): 0.5915165228479636\n",
      "Epoch 4/30\n",
      "408/408 [==============================] - 20s 48ms/step - loss: 0.0085\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5630076023551297 - normalized_discounted_cumulative_gain@5(0.0): 0.6314235897212269 - mean_average_precision(0.0): 0.5811220052309326\n",
      "Epoch 5/30\n",
      "408/408 [==============================] - 21s 52ms/step - loss: 0.0039\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5721875341813376 - normalized_discounted_cumulative_gain@5(0.0): 0.6430346137953176 - mean_average_precision(0.0): 0.5887684999840622\n",
      "Epoch 6/30\n",
      "408/408 [==============================] - 18s 44ms/step - loss: 0.0025\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5635600485727307 - normalized_discounted_cumulative_gain@5(0.0): 0.636814704018869 - mean_average_precision(0.0): 0.5831658911191673\n",
      "Epoch 7/30\n",
      "408/408 [==============================] - 19s 47ms/step - loss: 0.0019\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5624051725623126 - normalized_discounted_cumulative_gain@5(0.0): 0.6362646813359831 - mean_average_precision(0.0): 0.5830741350566369\n",
      "Epoch 8/30\n",
      "408/408 [==============================] - 20s 48ms/step - loss: 0.0020\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5627439034097886 - normalized_discounted_cumulative_gain@5(0.0): 0.6330671467013023 - mean_average_precision(0.0): 0.5806811015299475\n",
      "Epoch 9/30\n",
      "408/408 [==============================] - 21s 51ms/step - loss: 0.0018\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5679681248986371 - normalized_discounted_cumulative_gain@5(0.0): 0.6363760316713785 - mean_average_precision(0.0): 0.5851179931317684\n",
      "Epoch 10/30\n",
      "408/408 [==============================] - 22s 54ms/step - loss: 0.0020\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5745461650682577 - normalized_discounted_cumulative_gain@5(0.0): 0.6407437747766428 - mean_average_precision(0.0): 0.5893184033052796\n",
      "Epoch 11/30\n",
      "408/408 [==============================] - 21s 50ms/step - loss: 0.0020\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5684307665143323 - normalized_discounted_cumulative_gain@5(0.0): 0.634110297384481 - mean_average_precision(0.0): 0.5844463619617193\n",
      "Epoch 12/30\n",
      "408/408 [==============================] - 20s 48ms/step - loss: 0.0018\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.566182378404992 - normalized_discounted_cumulative_gain@5(0.0): 0.6318466678703829 - mean_average_precision(0.0): 0.583959239462311\n",
      "Epoch 13/30\n",
      "408/408 [==============================] - 21s 50ms/step - loss: 0.0020\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5671372070359241 - normalized_discounted_cumulative_gain@5(0.0): 0.6288034745152886 - mean_average_precision(0.0): 0.5823057242012007\n",
      "Epoch 14/30\n",
      "408/408 [==============================] - 24s 59ms/step - loss: 0.0020\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5688445292639432 - normalized_discounted_cumulative_gain@5(0.0): 0.6289506917256232 - mean_average_precision(0.0): 0.5823084213744116\n",
      "Epoch 15/30\n",
      "408/408 [==============================] - 21s 51ms/step - loss: 0.0016\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5616594491502604 - normalized_discounted_cumulative_gain@5(0.0): 0.6274555733102793 - mean_average_precision(0.0): 0.5803460611842033\n",
      "Epoch 16/30\n",
      "408/408 [==============================] - 24s 59ms/step - loss: 0.0019\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.56743611685671 - normalized_discounted_cumulative_gain@5(0.0): 0.6279655493568921 - mean_average_precision(0.0): 0.5833285217616007\n",
      "Epoch 17/30\n",
      "408/408 [==============================] - 21s 51ms/step - loss: 0.0019\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.56743611685671 - normalized_discounted_cumulative_gain@5(0.0): 0.629412929341684 - mean_average_precision(0.0): 0.5831775027001493\n",
      "Epoch 18/30\n",
      "408/408 [==============================] - 21s 53ms/step - loss: 0.0017\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5616594491502604 - normalized_discounted_cumulative_gain@5(0.0): 0.6230767811708631 - mean_average_precision(0.0): 0.5778751707969239\n",
      "Epoch 19/30\n",
      "408/408 [==============================] - 21s 51ms/step - loss: 0.0014\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5624755943132025 - normalized_discounted_cumulative_gain@5(0.0): 0.6196682608534165 - mean_average_precision(0.0): 0.5773872198947884\n",
      "Epoch 20/30\n",
      "408/408 [==============================] - 21s 51ms/step - loss: 0.0017\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.556962625552094 - normalized_discounted_cumulative_gain@5(0.0): 0.6181385580129402 - mean_average_precision(0.0): 0.5762216174819781\n",
      "Epoch 21/30\n",
      "408/408 [==============================] - 22s 53ms/step - loss: 0.0018\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5532956624869947 - normalized_discounted_cumulative_gain@5(0.0): 0.614101122041995 - mean_average_precision(0.0): 0.570779766786963\n",
      "Epoch 22/30\n",
      "408/408 [==============================] - 21s 51ms/step - loss: 0.0019\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5580675179872963 - normalized_discounted_cumulative_gain@5(0.0): 0.615310109292882 - mean_average_precision(0.0): 0.5723212571524482\n",
      "Epoch 23/30\n",
      "408/408 [==============================] - 21s 51ms/step - loss: 0.0017\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.555957813345946 - normalized_discounted_cumulative_gain@5(0.0): 0.6168348059854845 - mean_average_precision(0.0): 0.5745875152731366\n",
      "Epoch 24/30\n",
      "408/408 [==============================] - 20s 50ms/step - loss: 0.0020\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5545892219654028 - normalized_discounted_cumulative_gain@5(0.0): 0.6161692037842695 - mean_average_precision(0.0): 0.5737084716725741\n",
      "Epoch 25/30\n",
      "408/408 [==============================] - 21s 51ms/step - loss: 0.0015\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5572513728243542 - normalized_discounted_cumulative_gain@5(0.0): 0.6149353178966624 - mean_average_precision(0.0): 0.5726138530937782\n",
      "Epoch 26/30\n",
      "408/408 [==============================] - 20s 50ms/step - loss: 0.0019\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5530319635416537 - normalized_discounted_cumulative_gain@5(0.0): 0.6157477651673351 - mean_average_precision(0.0): 0.5719727220729004\n",
      "Epoch 27/30\n",
      "408/408 [==============================] - 22s 55ms/step - loss: 0.0018\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5535491988838099 - normalized_discounted_cumulative_gain@5(0.0): 0.6157701973864509 - mean_average_precision(0.0): 0.5706400633221479\n",
      "Epoch 28/30\n",
      "408/408 [==============================] - 24s 58ms/step - loss: 0.0017\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.556713699306524 - normalized_discounted_cumulative_gain@5(0.0): 0.6164860410473986 - mean_average_precision(0.0): 0.5726880826961166\n",
      "Epoch 29/30\n",
      "408/408 [==============================] - 25s 61ms/step - loss: 0.0015\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5586585674438334 - normalized_discounted_cumulative_gain@5(0.0): 0.6187911692779081 - mean_average_precision(0.0): 0.5763274637089153\n",
      "Epoch 30/30\n",
      "408/408 [==============================] - 24s 58ms/step - loss: 0.0014\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5642215197801578 - normalized_discounted_cumulative_gain@5(0.0): 0.6253202724847128 - mean_average_precision(0.0): 0.5820402362707493\n"
     ]
    }
   ],
   "source": [
    "history = model.fit_generator(train_generator, epochs=30, callbacks=[evaluate], workers=4, use_multiprocessing=True)"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "Use this function to update the README.md with a better set of parameters.\n",
    "Make sure you delete the correct section of the README.md before calling this function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-03-20T09:38:03.829486Z",
     "start_time": "2019-03-20T09:35:45.706Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# append_params_to_readme(model)"
   ]
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "matchzoo",
   "language": "python",
   "name": "matchzoo"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": "block",
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
